#!/usr/bin/env python
# $URL$
# $Rev$

# pipstack
# Combine input PNG files into a multi-channel output PNG.

"""
pipstack file1.png [file2.png ...]

pipstack can be used to combine 3 greyscale PNG files into a colour, RGB,
PNG file.  In fact it is slightly more general than that.  The number of
channels in the output PNG is equal to the sum of the numbers of
channels in the input images.  It is an error if this sum exceeds 4 (the
maximum number of channels in a PNG image is 4, for an RGBA image).  The
output colour model corresponds to the number of channels: 1 -
greyscale; 2 - greyscale+alpha; 3 - RGB; 4 - RGB+alpha.

In this way it is possible to combine 3 greyscale PNG files into an RGB
PNG (a common expected use) as well as more esoteric options: rgb.png +
grey.png = rgba.png; grey.png + grey.png = greyalpha.png.

Color Profile, Gamma, and so on.

[This is not implemented yet]

If an input has an ICC Profile (``iCCP`` chunk) then the output will
have an ICC Profile, but only if it is possible to combine all the input
ICC Profiles.  It is possible to combine all the input ICC Profiles
only when: they all use the same Profile Connection Space; the PCS white
point is the same (specified in the header; should always be D50);
possibly some other things I haven't thought of yet.

If some of the inputs have a ``gAMA`` chunk (specifying gamma) and
an output ICC Profile is being generated, then the gamma information
will be incorporated into the ICC Profile.

When the output is an RGB colour type and the output ICC Profile is
synthesized, it is necessary to supply colorant tags (``rXYZ`` and so
on).  These are taken from ``sRGB``.

If the input images have ``gAMA`` chunks and no input image has an ICC
Profile then the output image will have a ``gAMA`` chunk, but only if
all the ``gAMA`` chunks specify the same value.  Otherwise a warning
will be emitted and the output will have no ``gAMA`` chunk.  It is
possible to add or replace a ``gAMA`` chunk using the ``pipchunk`` tool.

gAMA, pHYs, iCCP, sRGB, tIME, any other chunks.
"""

class Error(Exception):
    """Raised by this tool when its input cannot be stacked."""

def stack(out, inp):
    """Stack the input PNG files into a single output PNG.

    The channels of every input image are concatenated, in order, to
    form the channels of the output image.

    out -- object handed to ``png.Writer.write`` as the output
      (presumably a binary file-like object; confirm against caller).
    inp -- sequence of inputs, each handed to ``png.Reader``
      (``main`` passes the command-line arguments, i.e. file names).

    Raises ``Error`` when `inp` is empty or when the total number of
    channels is not between 1 and 4.  Raises ``NotImplementedError``
    when the inputs do not all share one bit depth or one size.
    """

    from array import array
    import itertools
    # Local module
    import png

    if len(inp) < 1:
        raise Error("Required input is missing.")

    readers = [png.Reader(source) for source in inp]
    # Let data be a list of (pixel,info) pairs.  Real lists (rather than
    # lazy map objects) are needed because data is traversed several
    # times and indexed below.
    data = [reader.asDirect()[2:] for reader in readers]
    totalchannels = sum(info['planes'] for _, info in data)

    if not (0 < totalchannels <= 4):
        raise Error("Too many channels in input.")
    # Colour model follows from the channel count:
    # 1 - grey; 2 - grey+alpha; 3 - RGB; 4 - RGB+alpha.
    alpha = totalchannels in (2, 4)
    greyscale = totalchannels in (1, 2)

    # Flatten the per-image bit depths into one list; an image may
    # report either a single number or a sequence of them.
    bitdepths = []
    for _, info in data:
        depth = info['bitdepth']
        try:
            is_scalar = (depth == int(depth))
        except (TypeError, ValueError):
            is_scalar = False
        if is_scalar:
            bitdepths.append(depth)
        else:
            # Assume a tuple.
            bitdepths.extend(depth)
    # Currently, fail unless all bitdepths equal.
    if len(set(bitdepths)) > 1:
        raise NotImplementedError(
          "Cannot cope when bitdepths differ - sorry!")
    bitdepth = bitdepths[0]
    # 'B' (unsigned byte) up to 8 bits, 'H' (unsigned short) above.
    arraytype = 'BH'[bitdepth > 8]

    sizes = [info['size'] for _, info in data]
    # Currently, fail unless all images same size.
    if len(set(sizes)) > 1:
        raise NotImplementedError(
          "Cannot cope when sizes differ - sorry!")
    width, height = sizes[0][0], sizes[0][1]
    # Values per row
    vpr = totalchannels * width

    # Lazy zip on both Python 2 (izip) and Python 3 (builtin zip), so
    # rows are streamed rather than read into memory all at once.
    lazyzip = getattr(itertools, 'izip', zip)

    def iterstack():
        # the zip call creates an iterator that yields the next row
        # from all the input images combined into a tuple.
        for irow in lazyzip(*[pixels for pixels, _ in data]):
            row = array(arraytype, [0] * vpr)
            # output channel
            och = 0
            for i, arow in enumerate(irow):
                # ensure incoming row is an array
                arow = array(arraytype, arow)
                n = data[i][1]['planes']
                for j in range(n):
                    # Copy input channel j into output channel och;
                    # both are interleaved, hence the strided slices.
                    row[och::totalchannels] = arow[j::n]
                    och += 1
            yield row

    w = png.Writer(width, height,
      greyscale=greyscale, alpha=alpha, bitdepth=bitdepth)
    w.write(out, iterstack())


def main(argv=None):
    """Command-line entry point: stack the named PNG files to stdout.

    argv -- full argument list including the program name at index 0;
      defaults to ``sys.argv`` when None.

    Returns whatever ``stack`` returns.
    """
    import sys

    if argv is None:
        argv = sys.argv
    # The output is handed to a PNG writer, which emits binary data.
    # On Python 3 sys.stdout is a text stream that rejects bytes, so use
    # its underlying binary buffer when there is one; Python 2's
    # sys.stdout has no ``buffer`` attribute and is used directly.
    out = getattr(sys.stdout, 'buffer', sys.stdout)
    # Drop the program name; the rest are the input files.
    return stack(out, argv[1:])


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
